#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#include "dsp/dma_utils.h"
#include "dsp/hmx_mgr.h"
#include "dsp/hmx_utils.h"
#include "dsp/hvx_convert.h"
#include "dsp/hvx_internal.h"
#include "dsp/quants.h"
#include "dsp/utils.h"
#include "dsp/vtcm_mgr.h"
#include "dsp/worker_pool.h"

// debug & profile
#include "HAP_farf.h"
#include "HAP_perf.h"

#define WEIGHT_AREA_SIZE     (1 * 1024 * 1024)
#define ACTIVATION_AREA_SIZE (1 * 1024 * 1024)
#define OUTPUT_AREA_SIZE     (1 * 1024 * 1024)
#define SCRATCH_AREA_SIZE    (1 * 1024 * 1024)

static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
  -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
};

static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
  -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0,
  1,    0, 13,   0, 25,  0, 38,  0, 53,  0, 69,  0, 89,  0, 113, 0,
};

static const uint32_t common_layout_vscatter_offsets_base[32] __attribute__((aligned(VLEN))) = {
  0, 128, 256, 384, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920,
};

static inline void swap_ptr(void **p1, void **p2) {
  void *t = *p1;
  *p1     = *p2;
  *p2     = t;
}

static inline size_t get_super_block_size(enum ggml_type weight_type) {
  switch (weight_type) {
    case GGML_TYPE_Q4_0:
    case GGML_TYPE_IQ4_NL:
      return sizeof(my_block_q4_0);
    case GGML_TYPE_Q8_0:
      return sizeof(my_block_q8_0);
    default:
      return 0;
  }
}

static inline int dma_issue_load_from_ddr(dma_desc_1d_t *desc, void *vtcm_dst, const void *src, size_t size) {
  dma_wait_for_idle();

  desc->next       = 0;
  desc->length     = size;
  desc->type       = DMA_DESC_TYPE_1D;
  desc->src_bypass = 1;
  desc->dst_bypass = 0;
  desc->ordered    = 1;
  desc->dstate     = DMA_DESC_DSTATE_PENDING;
  desc->src        = (uint32_t) src;
  desc->dst        = (uint32_t) vtcm_dst;

  return dma_submit_one(desc);
}

static void find_chunk_size(size_t x_max, size_t y_max, size_t xy_max, size_t x_unit, size_t y_unit, size_t *x_out,
                            size_t *y_out) {
  int64_t best_xy = 0;
  size_t  best_x = 0, best_y = 0;

  for (size_t x = x_max; x > 0; x -= x_unit) {
    size_t  y  = smin(align_down(xy_max / x, y_unit), y_max);
    int64_t xy = x * y;
    if (best_xy < xy) {
      best_xy = xy;
      best_x = x, best_y = y;
    }
  }
  *x_out = best_x, *y_out = best_y;
}

// TODO(hzx): current implementation only use one thread. Use multiple threads to improve prefill performance
static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows,
                                                   int k_block, int k_stride) {
  assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0);
  assert(VLEN == 32 * sizeof(float));

  for (int r = 0; r < n_rows; r += 2) {
    int prefetch_row_idx = r + 2;
    if (prefetch_row_idx < n_rows) {
      const float *prefetch_addr = src + prefetch_row_idx * k_stride;
      // NOTE(hzx): prefetch 2 rows at a time
      l2fetch(prefetch_addr, k_stride * sizeof(float), k_block * sizeof(float), 2, 0);
    }

    int r0 = r / HMX_FP16_TILE_N_ROWS;  // tile row index
    int r1 = r % HMX_FP16_TILE_N_ROWS;  // intra-tile row idx

    const bool next_row_valid = (r + 1) < n_rows;

    const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride);
    const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride);
    for (int c = 0; c < k_block; c += 32) {
      HVX_Vector v0 = *pv_in0++;
      HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero();

      HVX_Vector v_out = hvx_my_wsf_to_vhf(v1, v0);

      // compute output position
      int c0       = c / HMX_FP16_TILE_N_COLS;  // tile column index
      int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0;

      HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS);
      tile[r1 / 2]     = v_out;
    }
  }
}

typedef struct {
  EXPAND_COMMON_TASK_STATE_MEMBERS
  int           k;
  __fp16       *dst;
  const __fp16 *src;
} permuted_weight_transfer_fp16_task_state_t;

static void transfer_permuted_weight_fp16_task(__fp16 *restrict vtcm_dst, const __fp16 *restrict src, int k,
                                               int n_col_tiles) {
  // transfer logical K*(32*n_col_tiles) weight block, direct copy, no extra computation
  size_t size   = k * n_col_tiles * HMX_FP16_TILE_N_COLS * sizeof(__fp16);
  int    n_vecs = size / VLEN;

  const size_t PREFETCH_SIZE   = 4096;
  const int    PREFETCH_N_VECS = PREFETCH_SIZE / VLEN;

  const HVX_Vector *pv_in  = (const HVX_Vector *) src;
  HVX_Vector       *pv_out = (HVX_Vector *) vtcm_dst;

  for (int i = 0; i < n_vecs; ++i) {
    if (i % PREFETCH_N_VECS == 0) {
      int prefetch_idx = i + PREFETCH_N_VECS;
      if (prefetch_idx < n_vecs) {
        size_t prefetch_n_vecs = smin(n_vecs - prefetch_idx, PREFETCH_N_VECS);
        l2fetch(pv_in + PREFETCH_N_VECS, VLEN, VLEN, prefetch_n_vecs, 0);
      }
    }

    *pv_out++ = *pv_in++;
  }
}

static void transfer_permuted_weight_fp16_worker_loop(void *data, int _worker_index) {
  (void) _worker_index;
  permuted_weight_transfer_fp16_task_state_t *state = (permuted_weight_transfer_fp16_task_state_t *) data;

  int k = state->k;

  while (1) {
    unsigned int task_id = worker_pool_atomic_inc_return(&(state->task_id)) - 1;
    if (task_id >= state->n_tasks) {
      break;
    }

    int    chunk_idx  = task_id * state->n_chunks_per_task;
    size_t chunk_size = smin(state->n_tot_chunks - chunk_idx, state->n_chunks_per_task);

    int           c        = chunk_idx * HMX_FP16_TILE_N_COLS;
    __fp16       *vtcm_dst = state->dst + c * k;
    const __fp16 *src      = state->src + c * k;
    transfer_permuted_weight_fp16_task(vtcm_dst, src, k, chunk_size);
  }

  worker_pool_synctoken_jobdone(&(state->sync_ctx));
}

static void transfer_permuted_weight_chunk_fp16(__fp16 *vtcm_dst, const __fp16 *src, int n_cols, int k) {
  // NOTE(hzx): weight matrix is already transposed. n_cols actually turns into n_rows
  assert(n_cols % HMX_FP16_TILE_N_COLS == 0);

  const bool use_dma = true;

  if (use_dma) {
    size_t size = n_cols * k * sizeof(__fp16);

    dma_desc_1d_t desc;
    dma_issue_load_from_ddr(&desc, vtcm_dst, src, size);
    dma_wait_for_idle();

    return;
  }

  int    n_workers         = num_hvx128_contexts;
  size_t n_tot_chunks      = n_cols / HMX_FP16_TILE_N_COLS;
  size_t n_chunks_per_task = ceil_div(n_tot_chunks, n_workers);
  // size_t n_chunks_per_task = 1;

  permuted_weight_transfer_fp16_task_state_t state;
  INIT_COMMON_TASK_STATE_MEMBERS(state, n_tot_chunks, n_chunks_per_task);
  state.k   = k;
  state.dst = vtcm_dst;
  state.src = src;

  worker_pool_job_t job;
  job.fptr = transfer_permuted_weight_fp16_worker_loop;
  job.dptr = &state;

  worker_pool_synctoken_init(&(state.sync_ctx), n_workers);
  for (int i = 0; i < n_workers; ++i) {
    worker_pool_submit(NULL, job);  // use default worker pool
  }
  worker_pool_synctoken_wait(&(state.sync_ctx));
}

typedef struct {
  EXPAND_COMMON_TASK_STATE_MEMBERS
  // NOTE: n_tot_chunks = number of total super-blocks
  __fp16        *dst;
  const void    *src;
  enum ggml_type quant_type;
  bool           src_in_vtcm;
  int            k;  // NOTE(hzx): only used in non-pre-permuted (common) weight case
} permuted_weight_dequantize_qk_0_hvx_task_state_t;

#define EXPAND_QK_0_VEC_SCALES_COMPUTATION(blk, vs0_c, vs1_c, vs2_c, vs3_c) \
  do {                                                                      \
    __fp16 s0 = blk.scales[0];                                              \
    __fp16 s1 = blk.scales[1];                                              \
    __fp16 s2 = blk.scales[2];                                              \
    __fp16 s3 = blk.scales[3];                                              \
    __fp16 s4 = blk.scales[4];                                              \
    __fp16 s5 = blk.scales[5];                                              \
    __fp16 s6 = blk.scales[6];                                              \
    __fp16 s7 = blk.scales[7];                                              \
                                                                            \
    HVX_Vector vs0 = Q6_Vh_vsplat_R(fp16_to_bits(&s0));                     \
    HVX_Vector vs1 = Q6_Vh_vsplat_R(fp16_to_bits(&s1));                     \
    HVX_Vector vs2 = Q6_Vh_vsplat_R(fp16_to_bits(&s2));                     \
    HVX_Vector vs3 = Q6_Vh_vsplat_R(fp16_to_bits(&s3));                     \
    HVX_Vector vs4 = Q6_Vh_vsplat_R(fp16_to_bits(&s4));                     \
    HVX_Vector vs5 = Q6_Vh_vsplat_R(fp16_to_bits(&s5));                     \
    HVX_Vector vs6 = Q6_Vh_vsplat_R(fp16_to_bits(&s6));                     \
    HVX_Vector vs7 = Q6_Vh_vsplat_R(fp16_to_bits(&s7));                     \
                                                                            \
    vs0_c = Q6_V_valign_VVR(vs1, vs0, 64);                                  \
    vs1_c = Q6_V_valign_VVR(vs3, vs2, 64);                                  \
    vs2_c = Q6_V_valign_VVR(vs5, vs4, 64);                                  \
    vs3_c = Q6_V_valign_VVR(vs7, vs6, 64);                                  \
  } while (0)

static inline HVX_Vector dequantize_single_q4_0_group(const block_q4_0 *group, const HVX_Vector vlut_cvt) {
  HVX_Vector vq = vmemu(&(group->quants));
  HVX_Vector vs = vmemu(&(group->scale));

  HVX_Vector v_scales = Q6_V_lo_W(Q6_Wh_vlut16_VbVhR_nomatch(Q6_V_vzero(), vs, 0));

  HVX_Vector v_qs_lo = vq;
  HVX_Vector v_qs_hi = Q6_Vub_vlsr_VubR(vq, 4);

  // concat lo & hi --> 32 elements in a group
  HVX_Vector v_lo_rot = Q6_V_vror_VR(v_qs_lo, 16);
  HVX_Vector v_quants = Q6_V_vlalign_VVR(v_qs_hi, v_lo_rot, 16);

  // convert INT4 -> FP16
  HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR_nomatch(v_quants, vlut_cvt, 0);

  HVX_Vector v_group_hf = Q6_V_lo_W(Q6_W_vshuff_VVR(Q6_V_hi_W(vp), Q6_V_lo_W(vp), -2));

  // // convert INT4 -> FP16
  // HVX_VectorPair vp_q0 = Q6_Wh_vlut16_VbVhR_nomatch(v_qs_lo, vlut_cvt, 0);
  // HVX_VectorPair vp_q1 = Q6_Wh_vlut16_VbVhR_nomatch(v_qs_hi, vlut_cvt, 0);

  // HVX_Vector v0 = Q6_V_lo_W(Q6_W_vshuff_VVR(Q6_V_hi_W(vp_q0), Q6_V_lo_W(vp_q0), -2));  // 16 valid elements
  // HVX_Vector v1 = Q6_V_lo_W(Q6_W_vshuff_VVR(Q6_V_hi_W(vp_q1), Q6_V_hi_W(vp_q1), -2));  // 16 valid elements

  // // concat v0 & v1 --> 32 elements in a group
  // HVX_Vector v0_rot     = Q6_V_vror_VR(v0, 32);
  // HVX_Vector v_group_hf = Q6_V_vlalign_VVR(v1, v0_rot, 32);

  // dequantize: quants(FP16) * values(FP16)
  v_group_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_group_hf, v_scales));
  return v_group_hf;
}

static inline HVX_Vector dequantize_single_q8_0_group(const block_q8_0 *group) {
  HVX_Vector vq = vmemu(&(group->quants));
  HVX_Vector vs = vmemu(&(group->scale));

  HVX_Vector v_scales = Q6_V_lo_W(Q6_Wh_vlut16_VbVhR_nomatch(Q6_V_vzero(), vs, 0));

  HVX_Vector v0         = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
  HVX_Vector v_group_hf = Q6_Vhf_equals_Vh(v0);

  // dequantize: quants(FP16) * values(FP16)
  v_group_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_group_hf, v_scales));
  return v_group_hf;
}

void dequantize_permuted_weight_q4_0_to_fp16_hvx_task(__fp16 *restrict vtcm_dst, const my_block_q4_0 *restrict src,
                                                      int n_blocks, bool src_in_vtcm, bool is_iq4_nl) {
  const int L2_PREFETCH_N_BLOCKS = 32;  // ~ 4K
  const int DC_PREFETCH_N_BLOCKS = 4;

  const bool no_group_coalesce = false;
  const bool no_dequantization = false;

  const HVX_Vector vlut_cvt = is_iq4_nl ? vmem(iq4_nl_to_fp16_lut) : vmem(q4_0_to_fp16_lut);

  static const uint8_t vlut_scales_idx_data[128] __attribute__((aligned(VLEN))) = {
    0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2,
    0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2,
    1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3,
    1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3,
  };
  const HVX_Vector vlut_scales_idx0 = vmem(vlut_scales_idx_data);
  const HVX_Vector vlut_scales_idx1 = Q6_Vb_vadd_VbVb(vlut_scales_idx0, Q6_Vb_vsplat_R(4));

  HVX_Vector *pv_out = (HVX_Vector *) vtcm_dst;

  for (int i = 0; i < n_blocks; ++i) {
    if (!src_in_vtcm && false) {
      if (i % L2_PREFETCH_N_BLOCKS == 0) {
        int prefetch_idx = i + L2_PREFETCH_N_BLOCKS;
        if (prefetch_idx < n_blocks) {
          size_t prefetch_n_blocks = smin(n_blocks - prefetch_idx, L2_PREFETCH_N_BLOCKS);
          l2fetch(src + prefetch_idx, sizeof(my_block_q4_0), sizeof(my_block_q4_0), prefetch_n_blocks, 0);
        }
      }

      if (i + DC_PREFETCH_N_BLOCKS < n_blocks) {
        Q6_dcfetch_A((void *) &(src[i + DC_PREFETCH_N_BLOCKS].scales));
      }
    }

    if (no_group_coalesce) {
      const block_q4_0 *groups = (const block_q4_0 *) (src + i);

      HVX_Vector v_g0 = dequantize_single_q4_0_group(groups + 0, vlut_cvt);
      HVX_Vector v_g1 = dequantize_single_q4_0_group(groups + 1, vlut_cvt);
      HVX_Vector v_g2 = dequantize_single_q4_0_group(groups + 2, vlut_cvt);
      HVX_Vector v_g3 = dequantize_single_q4_0_group(groups + 3, vlut_cvt);
      HVX_Vector v_g4 = dequantize_single_q4_0_group(groups + 4, vlut_cvt);
      HVX_Vector v_g5 = dequantize_single_q4_0_group(groups + 5, vlut_cvt);
      HVX_Vector v_g6 = dequantize_single_q4_0_group(groups + 6, vlut_cvt);
      HVX_Vector v_g7 = dequantize_single_q4_0_group(groups + 7, vlut_cvt);

      HVX_Vector v_g0_rot = Q6_V_vror_VR(v_g0, 64);
      HVX_Vector v_g2_rot = Q6_V_vror_VR(v_g2, 64);
      HVX_Vector v_g4_rot = Q6_V_vror_VR(v_g4, 64);
      HVX_Vector v_g6_rot = Q6_V_vror_VR(v_g6, 64);

      HVX_Vector v0 = Q6_V_vlalign_VVR(v_g1, v_g0_rot, 64);
      HVX_Vector v1 = Q6_V_vlalign_VVR(v_g3, v_g2_rot, 64);
      HVX_Vector v2 = Q6_V_vlalign_VVR(v_g5, v_g4_rot, 64);
      HVX_Vector v3 = Q6_V_vlalign_VVR(v_g7, v_g6_rot, 64);

      *pv_out++ = v0;
      *pv_out++ = v1;
      *pv_out++ = v2;
      *pv_out++ = v3;
      continue;
    }

    HVX_Vector qs = vmemu(src[i].quants);

    if (no_dequantization) {
      *pv_out++ = qs;
      *pv_out++ = qs;
      *pv_out++ = qs;
      *pv_out++ = qs;
      continue;
    }

    HVX_Vector v_qs_lo = qs;  // no need to mask out high 4 bits in each byte since vlut will do that for us
    HVX_Vector v_qs_hi = Q6_Vub_vlsr_VubR(qs, 4);

    HVX_VectorPair vp_q0 = Q6_Wh_vlut16_VbVhR_nomatch(v_qs_lo, vlut_cvt, 0);
    HVX_VectorPair vp_q1 = Q6_Wh_vlut16_VbVhR_nomatch(v_qs_hi, vlut_cvt, 0);

    // NOTE(hzx): the previous scalar->vector scales implementation is faster when src resides in DDR memory
    // HVX_Vector vs0_c, vs1_c, vs2_c, vs3_c;
    // EXPAND_QK_0_VEC_SCALES_COMPUTATION(src[i], vs0_c, vs1_c, vs2_c, vs3_c);

    HVX_Vector v_packed_scales = vmemu(src[i].scales);
    HVX_Vector vlut_scales     = Q6_V_lo_W(Q6_Wuw_vunpack_Vuh(v_packed_scales));

    HVX_VectorPair vp_s0 = Q6_Wh_vlut16_VbVhR_nomatch(vlut_scales_idx0, vlut_scales, 0);
    HVX_VectorPair vp_s1 = Q6_Wh_vlut16_VbVhR_nomatch(vlut_scales_idx1, vlut_scales, 0);

    HVX_Vector vs0_c = Q6_V_lo_W(vp_s0), vs1_c = Q6_V_hi_W(vp_s0);
    HVX_Vector vs2_c = Q6_V_lo_W(vp_s1), vs3_c = Q6_V_hi_W(vp_s1);

    *pv_out++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_q0), vs0_c));
    *pv_out++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_q0), vs1_c));
    *pv_out++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_q1), vs2_c));
    *pv_out++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_q1), vs3_c));
  }
}

void dequantize_permuted_weight_q8_0_to_fp16_hvx_task(__fp16 *restrict vtcm_dst, const my_block_q8_0 *restrict src,
                                                      int n_blocks, bool src_in_vtcm) {
  const int L2_PREFETCH_N_BLOCKS = 16;  // ~ 4K
  const int DC_PREFETCH_N_BLOCKS = 4;

  const bool no_group_coalesce = false;
  const bool no_dequantization = false;

  static const uint8_t vlut_scales_idx_data[128] __attribute__((aligned(VLEN))) = {
    0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2,
    0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2,
    1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3,
    1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3,
  };
  const HVX_Vector vlut_scales_idx0 = vmem(vlut_scales_idx_data);
  const HVX_Vector vlut_scales_idx1 = Q6_Vb_vadd_VbVb(vlut_scales_idx0, Q6_Vb_vsplat_R(4));

  HVX_Vector *pv_out = (HVX_Vector *) vtcm_dst;

  for (int i = 0; i < n_blocks; ++i) {
    if (!src_in_vtcm && false) {
      if (i % L2_PREFETCH_N_BLOCKS == 0) {
        int prefetch_idx = i + L2_PREFETCH_N_BLOCKS;
        if (prefetch_idx < n_blocks) {
          size_t prefetch_n_blocks = smin(n_blocks - prefetch_idx, L2_PREFETCH_N_BLOCKS);
          l2fetch(src + prefetch_idx, sizeof(my_block_q8_0), sizeof(my_block_q8_0), prefetch_n_blocks, 0);
        }
      }

      if (i + DC_PREFETCH_N_BLOCKS < n_blocks) {
        Q6_dcfetch_A((void *) &(src[i + DC_PREFETCH_N_BLOCKS].scales));
      }
    }

    if (no_group_coalesce) {
      const block_q8_0 *groups = (const block_q8_0 *) (src + i);

      HVX_Vector v_g0 = dequantize_single_q8_0_group(groups + 0);
      HVX_Vector v_g1 = dequantize_single_q8_0_group(groups + 1);
      HVX_Vector v_g2 = dequantize_single_q8_0_group(groups + 2);
      HVX_Vector v_g3 = dequantize_single_q8_0_group(groups + 3);
      HVX_Vector v_g4 = dequantize_single_q8_0_group(groups + 4);
      HVX_Vector v_g5 = dequantize_single_q8_0_group(groups + 5);
      HVX_Vector v_g6 = dequantize_single_q8_0_group(groups + 6);
      HVX_Vector v_g7 = dequantize_single_q8_0_group(groups + 7);

      HVX_Vector v_g0_rot = Q6_V_vror_VR(v_g0, 64);
      HVX_Vector v_g2_rot = Q6_V_vror_VR(v_g2, 64);
      HVX_Vector v_g4_rot = Q6_V_vror_VR(v_g4, 64);
      HVX_Vector v_g6_rot = Q6_V_vror_VR(v_g6, 64);

      HVX_Vector v0 = Q6_V_vlalign_VVR(v_g1, v_g0_rot, 64);
      HVX_Vector v1 = Q6_V_vlalign_VVR(v_g3, v_g2_rot, 64);
      HVX_Vector v2 = Q6_V_vlalign_VVR(v_g5, v_g4_rot, 64);
      HVX_Vector v3 = Q6_V_vlalign_VVR(v_g7, v_g6_rot, 64);

      *pv_out++ = v0;
      *pv_out++ = v1;
      *pv_out++ = v2;
      *pv_out++ = v3;
      continue;
    }

    HVX_Vector vq0 = vmemu(src[i].quants);
    HVX_Vector vq1 = vmemu(src[i].quants + VLEN);

    if (no_dequantization) {
      *pv_out++ = vq0;
      *pv_out++ = vq0;
      *pv_out++ = vq1;
      *pv_out++ = vq1;
      continue;
    }

    HVX_VectorPair vp0 = Q6_Wh_vunpack_Vb(vq0);
    HVX_VectorPair vp1 = Q6_Wh_vunpack_Vb(vq1);

    HVX_Vector v0 = Q6_Vhf_equals_Vh(Q6_V_lo_W(vp0));
    HVX_Vector v1 = Q6_Vhf_equals_Vh(Q6_V_hi_W(vp0));
    HVX_Vector v2 = Q6_Vhf_equals_Vh(Q6_V_lo_W(vp1));
    HVX_Vector v3 = Q6_Vhf_equals_Vh(Q6_V_hi_W(vp1));

    // HVX_Vector vs0_c, vs1_c, vs2_c, vs3_c;
    // EXPAND_QK_0_VEC_SCALES_COMPUTATION(src[i], vs0_c, vs1_c, vs2_c, vs3_c);

    HVX_Vector v_packed_scales = vmemu(src[i].scales);
    HVX_Vector vlut_scales     = Q6_V_lo_W(Q6_Wuw_vunpack_Vuh(v_packed_scales));

    HVX_VectorPair vp_s0 = Q6_Wh_vlut16_VbVhR_nomatch(vlut_scales_idx0, vlut_scales, 0);
    HVX_VectorPair vp_s1 = Q6_Wh_vlut16_VbVhR_nomatch(vlut_scales_idx1, vlut_scales, 0);

    HVX_Vector vs0_c = Q6_V_lo_W(vp_s0), vs1_c = Q6_V_hi_W(vp_s0);
    HVX_Vector vs2_c = Q6_V_lo_W(vp_s1), vs3_c = Q6_V_hi_W(vp_s1);

    *pv_out++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v0, vs0_c));
    *pv_out++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v1, vs1_c));
    *pv_out++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v2, vs2_c));
    *pv_out++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v3, vs3_c));
  }
}

static void dequantize_permuted_weight_qk_0_to_fp16_hvx_worker_loop(void *data, int _worker_index) {
  (void) _worker_index;
  permuted_weight_dequantize_qk_0_hvx_task_state_t *state = (permuted_weight_dequantize_qk_0_hvx_task_state_t *) data;

  while (1) {
    unsigned int task_id = worker_pool_atomic_inc_return(&(state->task_id)) - 1;
    if (task_id >= state->n_tasks) {
      break;
    }

    int    chunk_idx  = task_id * state->n_chunks_per_task;
    size_t chunk_size = smin(state->n_tot_chunks - chunk_idx, state->n_chunks_per_task);

    __fp16 *vtcm_dst = state->dst + chunk_idx * QK_K;

    if (state->quant_type == GGML_TYPE_Q4_0 || state->quant_type == GGML_TYPE_IQ4_NL) {
      const my_block_q4_0 *src = ((const my_block_q4_0 *) state->src) + chunk_idx;
      dequantize_permuted_weight_q4_0_to_fp16_hvx_task(vtcm_dst, src, chunk_size, state->src_in_vtcm,
                                                       state->quant_type == GGML_TYPE_IQ4_NL);
    } else if (state->quant_type == GGML_TYPE_Q8_0) {
      const my_block_q8_0 *src = ((const my_block_q8_0 *) state->src) + chunk_idx;
      dequantize_permuted_weight_q8_0_to_fp16_hvx_task(vtcm_dst, src, chunk_size, state->src_in_vtcm);
    }
  }

  worker_pool_synctoken_jobdone(&(state->sync_ctx));
}

void dequantize_permuted_weight_chunk_qk_0_to_fp16_hvx(__fp16 *vtcm_dst, const void *src, int ne, int k,
                                                       enum ggml_type type, void *vtcm_scratch) {
  assert(ne % QK_K == 0);
  (void) k;

  const bool src_in_vtcm = true;

  int    n_workers         = num_hvx128_contexts;
  size_t n_tot_chunks      = ne / QK_K;
  size_t n_chunks_per_task = ceil_div(n_tot_chunks, n_workers);

  permuted_weight_dequantize_qk_0_hvx_task_state_t state;
  INIT_COMMON_TASK_STATE_MEMBERS(state, n_tot_chunks, n_chunks_per_task);
  state.dst         = vtcm_dst;
  state.src         = src_in_vtcm ? vtcm_scratch : src;
  state.quant_type  = type;
  state.src_in_vtcm = src_in_vtcm;

  worker_pool_job_t job;
  job.fptr = dequantize_permuted_weight_qk_0_to_fp16_hvx_worker_loop;
  job.dptr = &state;

  // int64_t t0 = HAP_perf_get_qtimer_count();

  worker_pool_synctoken_init(&(state.sync_ctx), n_workers);
  for (int i = 0; i < n_workers; ++i) {
    worker_pool_submit(NULL, job);  // use default worker pool
  }
  worker_pool_synctoken_wait(&(state.sync_ctx));

  // int64_t e = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - t0);
  // FARF(ALWAYS, "QK_0 dequantize: ne: %d time: %lld us", ne, e);
}

void dequantize_common_weight_q4_0_to_fp16_hvx_task(__fp16 *restrict vtcm_dst, const block_q4_0 *restrict src,
                                                    int start_group_idx, int end_group_idx, int k, bool src_in_vtcm,
                                                    bool is_iq4_nl) {
  assert(src_in_vtcm);

  const size_t GROUP_SIZE = QK_0;
  assert(GROUP_SIZE == HMX_FP16_TILE_N_ROWS);

  const size_t N_GROUPS_PER_SCALAR_COLUMN = k / GROUP_SIZE;
  const size_t N_GROUPS_PER_TILE_COLUMN   = N_GROUPS_PER_SCALAR_COLUMN * HMX_FP16_TILE_N_COLS;

  const HVX_Vector vlut_cvt       = is_iq4_nl ? vmem(iq4_nl_to_fp16_lut) : vmem(q4_0_to_fp16_lut);
  const HVX_Vector v_offsets_base = vmem(common_layout_vscatter_offsets_base);

  const HVX_VectorPred q_32_elems_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16));

  for (int g_idx = start_group_idx; g_idx < end_group_idx; ++g_idx) {
    const block_q4_0 *group = src + g_idx;

    HVX_Vector v_group_hf = dequantize_single_q4_0_group(group, vlut_cvt);

    // prepare for scatter
    int i0 = g_idx / N_GROUPS_PER_TILE_COLUMN;
    int i1 = g_idx % N_GROUPS_PER_TILE_COLUMN;

    int gr = i1 % N_GROUPS_PER_SCALAR_COLUMN;
    int gc = i1 / N_GROUPS_PER_SCALAR_COLUMN;

    int     tile_idx  = i0 * (k / HMX_FP16_TILE_N_ROWS) + gr;
    __fp16 *tile_base = vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS;

    HVX_Vector v_offsets_delta = Q6_V_vsplat_R(gc * 4);
    HVX_Vector v_offsets       = Q6_Vw_vadd_VwVw(v_offsets_base, v_offsets_delta);
    Q6_vscatter_QRMVwV(q_32_elems_mask, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_group_hf);
  }
}

void dequantize_common_weight_q8_0_to_fp16_hvx_task(__fp16 *restrict vtcm_dst, const block_q8_0 *restrict src,
                                                    int start_group_idx, int end_group_idx, int k, bool src_in_vtcm) {
  const size_t GROUP_SIZE = QK_0;
  assert(GROUP_SIZE == HMX_FP16_TILE_N_ROWS);

  const size_t N_GROUPS_PER_SCALAR_COLUMN = k / GROUP_SIZE;
  const size_t N_GROUPS_PER_TILE_COLUMN   = N_GROUPS_PER_SCALAR_COLUMN * HMX_FP16_TILE_N_COLS;

  const HVX_Vector v_offsets_base = vmem(common_layout_vscatter_offsets_base);

  const HVX_VectorPred q_32_elems_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16));

  for (int g_idx = start_group_idx; g_idx < end_group_idx; ++g_idx) {
    const block_q8_0 *group = src + g_idx;

    HVX_Vector v_group_hf = dequantize_single_q8_0_group(group);

    // prepare for scatter
    int i0 = g_idx / N_GROUPS_PER_TILE_COLUMN;
    int i1 = g_idx % N_GROUPS_PER_TILE_COLUMN;

    int gr = i1 % N_GROUPS_PER_SCALAR_COLUMN;
    int gc = i1 / N_GROUPS_PER_SCALAR_COLUMN;

    int     tile_idx  = i0 * (k / HMX_FP16_TILE_N_ROWS) + gr;
    __fp16 *tile_base = vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS;

    HVX_Vector v_offsets_delta = Q6_V_vsplat_R(gc * 4);
    HVX_Vector v_offsets       = Q6_Vw_vadd_VwVw(v_offsets_base, v_offsets_delta);
    Q6_vscatter_QRMVwV(q_32_elems_mask, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_group_hf);
  }
}

static void dequantize_common_weight_chunk_qk_0_to_fp16_hvx_worker_loop(void *data, int _worker_index) {
  (void) _worker_index;
  permuted_weight_dequantize_qk_0_hvx_task_state_t *state = (permuted_weight_dequantize_qk_0_hvx_task_state_t *) data;

  const int     k        = state->k;
  __fp16 *const vtcm_dst = state->dst;

  while (1) {
    unsigned int task_id = worker_pool_atomic_inc_return(&(state->task_id)) - 1;
    if (task_id >= state->n_tasks) {
      break;
    }

    size_t start_idx = task_id * state->n_chunks_per_task;
    size_t end_idx   = smin(start_idx + state->n_chunks_per_task, state->n_tot_chunks);

    if (state->quant_type == GGML_TYPE_Q4_0 || state->quant_type == GGML_TYPE_IQ4_NL) {
      const block_q4_0 *src = (const block_q4_0 *) state->src;
      dequantize_common_weight_q4_0_to_fp16_hvx_task(vtcm_dst, src, start_idx, end_idx, k, state->src_in_vtcm,
                                                     state->quant_type == GGML_TYPE_IQ4_NL);
    } else if (state->quant_type == GGML_TYPE_Q8_0) {
      const block_q8_0 *src = (const block_q8_0 *) state->src;
      dequantize_common_weight_q8_0_to_fp16_hvx_task(vtcm_dst, src, start_idx, end_idx, k, state->src_in_vtcm);
    }
  }

  worker_pool_synctoken_jobdone(&(state->sync_ctx));
}

void dequantize_common_weight_chunk_qk_0_to_fp16_hvx(__fp16 *vtcm_dst, const void *src, int ne, int k,
                                                     enum ggml_type type, void *vtcm_scratch) {
  assert(ne % QK_0 == 0);
  assert(k % QK_0 == 0);

  const bool src_in_vtcm = true;

  int    n_workers         = num_hvx128_contexts;
  size_t n_tot_chunks      = ne / QK_0;
  size_t n_chunks_per_task = ceil_div(n_tot_chunks, n_workers);

  // NOTE: reuse the task state type for now
  permuted_weight_dequantize_qk_0_hvx_task_state_t state;
  INIT_COMMON_TASK_STATE_MEMBERS(state, n_tot_chunks, n_chunks_per_task);
  state.dst         = vtcm_dst;
  state.src         = src_in_vtcm ? vtcm_scratch : src;
  state.quant_type  = type;
  state.src_in_vtcm = src_in_vtcm;
  state.k           = k;

  worker_pool_job_t job;
  job.fptr = dequantize_common_weight_chunk_qk_0_to_fp16_hvx_worker_loop;
  job.dptr = &state;

  worker_pool_synctoken_init(&(state.sync_ctx), n_workers);
  for (int i = 0; i < n_workers; ++i) {
    worker_pool_submit(NULL, job);  // use default worker pool
  }
  worker_pool_synctoken_wait(&(state.sync_ctx));
}

static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const __fp16 *weight, const __fp16 *scales,
                                int n_row_tiles, int n_col_tiles, int n_dot_tiles) {
  hmx_unit_acquire();

  asm volatile("mxclracc.hf");
  hmx_set_output_scales(scales);

  for (int r = 0; r < n_row_tiles; ++r) {
    for (int c = 0; c < n_col_tiles; ++c) {
      const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
      const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS;

      for (int k = 0; k < n_dot_tiles; k += 32) {
        int    offset  = k * HMX_FP16_TILE_N_ELMS;
        size_t n_tiles = smin(n_dot_tiles - k, 32);
        hmx_load_tiles_fp16(row_tiles + offset, col_tiles + offset, n_tiles);
      }

      __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS;
      hmx_consume_accumulator_fp16(out_tile);
    }
  }

  hmx_unit_release();
}

// TODO(hzx): current implementation only use one thread. Use multiple threads to improve prefill performance
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows,
                                               int n_cols, int n) {
  assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
  const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;

  for (int r = 0; r < n_rows; r += 2) {
    int r0 = r / HMX_FP16_TILE_N_ROWS;
    int r1 = r % HMX_FP16_TILE_N_ROWS;

    for (int c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) {
      int c0 = c / HMX_FP16_TILE_N_COLS;

      const __fp16 *tile = vtcm_src + (r0 * n_col_tiles + c0) * HMX_FP16_TILE_N_ELMS;

      HVX_Vector v_src = ((const HVX_Vector *) tile)[r1 / 2];

      HVX_VectorPair vp = hvx_my_vhf_to_wsf(v_src);

      HVX_Vector *pv_out0 = (HVX_Vector *) (dst + (r * n + c + 0));
      HVX_Vector *pv_out1 = (HVX_Vector *) (dst + (r * n + c + n));  // next row in global memory

      *pv_out0 = Q6_V_lo_W(vp);
      if (r + 1 < n_rows) {
        *pv_out1 = Q6_V_hi_W(vp);
      }
    }
  }
}

int hmx_mat_mul_permuted_w16a32(float *restrict dst, const float *restrict activation,
                                const __fp16 *restrict permuted_weight, int m, int k, int n) {
  if (!dst || !activation || !permuted_weight || !m || !n || !k) {
    return -1;
  }
  if (k % 32 != 0 || n % 32 != 0) {
    // TODO(hzx): can we remove this restriction?
    return -1;
  }
  if (!is_aligned(dst, VLEN) || !is_aligned(activation, VLEN) || !is_aligned(permuted_weight, VLEN)) {
    return -1;
  }

  const size_t weight_area_size     = WEIGHT_AREA_SIZE;
  const size_t activation_area_size = ACTIVATION_AREA_SIZE;
  const size_t output_area_size     = OUTPUT_AREA_SIZE;

  // VTCM layout: weight | activation | output | scales
  uint8_t *vtcm_ptr        = (uint8_t *) vtcm_manager_get_vtcm_base();
  __fp16  *vtcm_weight     = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
  __fp16  *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
  __fp16  *vtcm_output     = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
  __fp16  *vtcm_scales     = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);

  hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00));  // fp16: 1.0

  size_t vec_dot_size       = k * sizeof(__fp16);
  size_t m_chunk_max_n_rows = align_down(activation_area_size / vec_dot_size, HMX_FP16_TILE_N_ROWS);
  size_t n_chunk_max_n_cols = align_down(weight_area_size / vec_dot_size, HMX_FP16_TILE_N_COLS);

  size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0;
  find_chunk_size(m_chunk_max_n_rows, n_chunk_max_n_cols, output_area_size / sizeof(__fp16), HMX_FP16_TILE_N_ROWS,
                  HMX_FP16_TILE_N_COLS, &m_chunk_n_rows, &n_chunk_n_cols);

  // FARF(ALWAYS, "computed chunk size: %d, %d", m_chunk_n_rows, n_chunk_n_cols);
  assert(m_chunk_n_rows > 0 && n_chunk_n_cols > 0);

  // int64_t activation_load_time, weight_load_time, hmx_core_time, output_store_time;
  // activation_load_time = weight_load_time = hmx_core_time = output_store_time = 0;

  for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
    // transfer activation matrix chunk into VTCM
    size_t n_rows = smin(m - mr, m_chunk_n_rows);

    // int64_t act_t0 = HAP_perf_get_qtimer_count();
    {
      const float *activation_chunk = activation + mr * k;
      transfer_activation_chunk_fp32_to_fp16(vtcm_activation, activation_chunk, n_rows, k, k);
    }
    // activation_load_time += HAP_perf_get_qtimer_count() - act_t0;

    // FARF(ALWAYS, "transfer activation ok, mr = %d, n_rows = %d", mr, n_rows);

    for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
      size_t n_cols = smin(n - nc, n_chunk_n_cols);

      // int64_t wei_t0 = HAP_perf_get_qtimer_count();
      {
        const __fp16 *permuted_weight_chunk = permuted_weight + nc * k;
        transfer_permuted_weight_chunk_fp16(vtcm_weight, permuted_weight_chunk, n_cols, k);
      }
      // weight_load_time += HAP_perf_get_qtimer_count() - wei_t0;

      // FARF(ALWAYS, "transfer weight ok, nc = %d, n_cols = %d", nc, n_cols);

      // int64_t core_t0 = HAP_perf_get_qtimer_count();
      {
        const int n_row_tiles = ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
        const int n_col_tiles = ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
        core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
      }
      // hmx_core_time += HAP_perf_get_qtimer_count() - core_t0;

      // FARF(ALWAYS, "core compute ok, (%d, %d) tiles", n_row_tiles, n_col_tiles);

      // int64_t out_t0 = HAP_perf_get_qtimer_count();
      {
        float *output = dst + (mr * n + nc);
        transfer_output_chunk_fp16_to_fp32(output, vtcm_output, n_rows, n_cols, n);
      }
      // output_store_time += HAP_perf_get_qtimer_count() - out_t0;

      // FARF(ALWAYS, "transfer output ok, (%d, %d)", mr, nc);
    }
  }

  // FARF(ALWAYS, "%s: m = %d, k = %d, n = %d", __func__, m, k, n);
  // FARF(ALWAYS, "    activation load: %lld us", HAP_perf_qtimer_count_to_us(activation_load_time));
  // FARF(ALWAYS, "    weight     load: %lld us", HAP_perf_qtimer_count_to_us(weight_load_time));
  // FARF(ALWAYS, "    core     matmul: %lld us", HAP_perf_qtimer_count_to_us(hmx_core_time));
  // FARF(ALWAYS, "    output    store: %lld us", HAP_perf_qtimer_count_to_us(output_store_time));

  // size_t weight_size = k * n * sizeof(__fp16);
  // float  bandwidth   = 1e-3 * weight_size / HAP_perf_qtimer_count_to_us(weight_load_time);
  // FARF(ALWAYS, "    weight load bandwidth: %.2f GB/s", bandwidth);

  return 0;
}

extern worker_pool_context_t hmx_worker_pool_ctx;

typedef struct {
  __fp16            *c;
  const __fp16      *a, *b, *s;
  int                n_row_tiles, n_col_tiles, n_dot_tiles;
  worker_synctoken_t sync_ctx;
} core_dot_fp16_task_state_t;

static void core_dot_fp16_hmx_worker_fn(void *data, int _worker_index) {
  (void) _worker_index;
  core_dot_fp16_task_state_t *st = (core_dot_fp16_task_state_t *) data;

  core_dot_chunk_fp16(st->c, st->a, st->b, st->s, st->n_row_tiles, st->n_col_tiles, st->n_dot_tiles);

  worker_pool_synctoken_jobdone(&st->sync_ctx);
}

int mat_mul_qk_0_d16a32_out_stationary(float *restrict out, const float *restrict x, const uint8_t *restrict w, int m,
                                       int k, int n, enum ggml_type w_type);

int hmx_mat_mul_permuted_qk_0_d16a32(float *restrict dst, const float *restrict activation,
                                     const uint8_t *restrict permuted_weight, int m, int k, int n,
                                     enum ggml_type weight_type) {
  if (!dst || !activation || !permuted_weight || !m || !n || !k) {
    return -1;
  }
  if (k % 32 != 0 || n % 32 != 0) {
    // TODO(hzx): can we remove this restriction?
    return -1;
  }
  if (!is_aligned(dst, VLEN) || !is_aligned(activation, VLEN) || !is_aligned(permuted_weight, VLEN)) {
    return -1;
  }

  // for large m, k (e.g. prefill FFN Down), use out-staionary version
  if (m >= 128 && k > n && n > 1024) {
    return mat_mul_qk_0_d16a32_out_stationary(dst, activation, permuted_weight, m, k, n, weight_type);
  }

  size_t super_block_size = get_super_block_size(weight_type);
  if (super_block_size == 0) {
    return -1;
  }

  const size_t weight_area_size     = WEIGHT_AREA_SIZE;
  const size_t activation_area_size = ACTIVATION_AREA_SIZE;
  const size_t output_area_size     = OUTPUT_AREA_SIZE;
  const size_t scratch_area_size    = SCRATCH_AREA_SIZE;

  // VTCM layout: weight | activation | output | scales
  uint8_t *vtcm_ptr        = (uint8_t *) vtcm_manager_get_vtcm_base();
  __fp16  *vtcm_weight     = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
  __fp16  *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
  __fp16  *vtcm_output     = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
  void    *vtcm_scratch0   = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
  void    *vtcm_scratch1   = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
  void    *vtcm_scratch2   = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
  __fp16  *vtcm_scales     = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);

  hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00));  // fp16: 1.0

  size_t vec_dot_size       = k * sizeof(__fp16);
  size_t m_chunk_max_n_rows = align_down(activation_area_size / vec_dot_size, HMX_FP16_TILE_N_ROWS);
  size_t n_chunk_max_n_cols = align_down(weight_area_size / vec_dot_size, HMX_FP16_TILE_N_COLS);

  size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0;
  find_chunk_size(m_chunk_max_n_rows, n_chunk_max_n_cols, output_area_size / sizeof(__fp16), HMX_FP16_TILE_N_ROWS,
                  HMX_FP16_TILE_N_COLS, &m_chunk_n_rows, &n_chunk_n_cols);

  // FARF(ALWAYS, "computed chunk size: %d, %d", m_chunk_n_rows, n_chunk_n_cols);
  assert(m_chunk_n_rows > 0 && n_chunk_n_cols > 0);

  // int64_t activation_load_time, weight_load_time, hmx_core_time, output_store_time;
  // activation_load_time = weight_load_time = hmx_core_time = output_store_time = 0;

  const bool use_pipeline = (m >= 128) && (k <= n);
  // const bool use_pipeline = false;

  if (!use_pipeline) {
    // NOTE(hzx): In this simple implementation, load-matmul-store are executed sequentially
    // only DMA load and dequantization process are overlapped during the load stage

    for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
      // transfer activation matrix chunk into VTCM
      size_t n_rows = smin(m - mr, m_chunk_n_rows);

      // int64_t act_t0 = HAP_perf_get_qtimer_count();
      {
        const float *activation_chunk = activation + mr * k;
        transfer_activation_chunk_fp32_to_fp16(vtcm_activation, activation_chunk, n_rows, k, k);
      }
      // activation_load_time += HAP_perf_get_qtimer_count() - act_t0;

      // FARF(ALWAYS, "transfer activation ok, mr = %d, n_rows = %d", mr, n_rows);

      void *buf_curr = vtcm_scratch0;
      void *buf_next = vtcm_scratch1;

      static dma_desc_1d_t desc
        __attribute__((aligned(64)));  // NOTE(hzx): make sure the DMA descriptor's lifetime is long enough

      // issue async DDR data transfer for the first weight chunk
      {
        const size_t n_cols_first            = smin(n, n_chunk_n_cols);
        const size_t first_weight_chunk_size = n_cols_first * k / QK_K * super_block_size;

        dma_issue_load_from_ddr(&desc, buf_curr, permuted_weight, first_weight_chunk_size);
      }

      for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
        size_t n_cols = smin(n - nc, n_chunk_n_cols);

        // int64_t wei_t0 = HAP_perf_get_qtimer_count();
        {
          dma_wait_for_idle();  // wait until current weight chunk become ready

          const size_t nc_next = nc + n_chunk_n_cols;
          if (nc_next < n) {
            const size_t n_cols_next = smin(n - nc_next, n_chunk_n_cols);

            const size_t   next_weight_chunk_size = n_cols_next * k / QK_K * super_block_size;
            const uint8_t *next_weight_chunk      = permuted_weight + nc_next * k / QK_K * super_block_size;

            dma_issue_load_from_ddr(&desc, buf_next, next_weight_chunk, next_weight_chunk_size);
          }

          const uint8_t *permuted_weight_chunk = permuted_weight + (nc * k / QK_K) * super_block_size;
          dequantize_permuted_weight_chunk_qk_0_to_fp16_hvx(vtcm_weight, permuted_weight_chunk, n_cols * k, k,
                                                            weight_type, buf_curr);

          swap_ptr(&buf_curr, &buf_next);
        }
        // weight_load_time += HAP_perf_get_qtimer_count() - wei_t0;

        // FARF(ALWAYS, "transfer weight ok, nc = %d, n_cols = %d", nc, n_cols);

        // int64_t core_t0 = HAP_perf_get_qtimer_count();
        {
          const int n_row_tiles = ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
          const int n_col_tiles = ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
          core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
        }
        // hmx_core_time += HAP_perf_get_qtimer_count() - core_t0;

        // FARF(ALWAYS, "core compute ok, (%d, %d) tiles", n_row_tiles, n_col_tiles);

        // int64_t out_t0 = HAP_perf_get_qtimer_count();
        {
          float *output = dst + (mr * n + nc);
          transfer_output_chunk_fp16_to_fp32(output, vtcm_output, n_rows, n_cols, n);
        }
        // output_store_time += HAP_perf_get_qtimer_count() - out_t0;

        // FARF(ALWAYS, "transfer output ok, (%d, %d)", mr, nc);
      }
    }
  } else {
    // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
    // stage B and D (dequantize and store) are expected to be on the critical path

    // A --> B: vtcm_qweight, 1 buffer
    // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
    // C --> D: vtcm_output0/vtcm_output1, 2 buffers

    //
    // LD ||A3|  | B3 ||
    // MM ||    C2    ||
    // ST || D1 |     ||

    static dma_desc_1d_t _Alignas(64) dma_desc;
    static core_dot_fp16_task_state_t mm_task_state;
    static worker_pool_job_t          mm_task_job;

    mm_task_job.dptr = &mm_task_state;
    mm_task_job.fptr = &core_dot_fp16_hmx_worker_fn;

    int n_chunk_cnt = ceil_div(n, n_chunk_n_cols);
    for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
      const size_t n_rows = smin(m - mr, m_chunk_n_rows);

      void *vtcm_qweight        = vtcm_weight;
      void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 };
      void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 };

      // prologue: A0
      const size_t n_cols_A0 = smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols);
      {
        const size_t chunk_size_A0 = n_cols_A0 * k / QK_K * super_block_size;

        const uint8_t *qweight_chunk_A0 = permuted_weight;
        dma_issue_load_from_ddr(&dma_desc, vtcm_qweight, qweight_chunk_A0, chunk_size_A0);
      }

      {
        const float *activation_chunk = activation + mr * k;
        transfer_activation_chunk_fp32_to_fp16(vtcm_activation, activation_chunk, n_rows, k, k);
      }

      // prologue: B0, A1, C0, B1
      {
        // B0
        dma_wait_for_idle();
        dequantize_permuted_weight_chunk_qk_0_to_fp16_hvx(vtcm_weight_bufs[0], NULL, n_cols_A0 * k, k, weight_type,
                                                          vtcm_qweight);

        // A1
        const size_t n_cols_A1 = smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
        if (1 < n_chunk_cnt) {
          const size_t chunk_size_A1 = n_cols_A1 * k / QK_K * super_block_size;

          const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * k / QK_K * super_block_size;
          dma_issue_load_from_ddr(&dma_desc, vtcm_qweight, qweight_chunk_A1, chunk_size_A1);
        }

        // C0
        {
          core_dot_fp16_task_state_t *s = &mm_task_state;

          s->c = (__fp16 *) vtcm_output_bufs[0];
          s->a = (__fp16 *) vtcm_activation;
          s->b = (__fp16 *) vtcm_weight_bufs[0];
          s->s = vtcm_scales;

          s->n_row_tiles = ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
          s->n_col_tiles = ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS);
          s->n_dot_tiles = k / HMX_FP16_TILE_N_ROWS;

          worker_pool_synctoken_init(&s->sync_ctx, 1);
          worker_pool_submit(hmx_worker_pool_ctx, mm_task_job);
        }

        // B1
        if (1 < n_chunk_cnt) {
          dma_wait_for_idle();
          dequantize_permuted_weight_chunk_qk_0_to_fp16_hvx(vtcm_weight_bufs[1], NULL, n_cols_A1 * k, k, weight_type,
                                                            vtcm_qweight);
        }
      }

      // main loop
      for (int i = 0; i < n_chunk_cnt; ++i) {
        const size_t nc    = i * n_chunk_n_cols;
        const size_t nc_p1 = nc + 1 * n_chunk_n_cols;
        const size_t nc_p2 = nc + 2 * n_chunk_n_cols;

        const size_t n_cols    = smin(n - nc, n_chunk_n_cols);
        const size_t n_cols_p1 = smin(n - nc_p1, n_chunk_n_cols);
        const size_t n_cols_p2 = smin(n - nc_p2, n_chunk_n_cols);

        // issue A_{i+2}
        if (i + 2 < n_chunk_cnt) {
          const size_t   chunk_size_p2    = n_cols_p2 * k / QK_K * super_block_size;
          const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * k / QK_K * super_block_size;
          dma_issue_load_from_ddr(&dma_desc, vtcm_qweight, qweight_chunk_p2, chunk_size_p2);
        }

        // wait for HMX (C_{i})
        worker_pool_synctoken_wait(&mm_task_state.sync_ctx);

        // result of B_{i+1} (input of C_{i+1}) should be ready now

        // issue C_{i+1}
        if (i + 1 < n_chunk_cnt) {
          core_dot_fp16_task_state_t *s = &mm_task_state;

          s->c = (__fp16 *) vtcm_output_bufs[(i + 1) % 2];
          s->a = (__fp16 *) vtcm_activation;
          s->b = (__fp16 *) vtcm_weight_bufs[(i + 1) % 2];
          s->s = vtcm_scales;

          s->n_row_tiles = ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
          s->n_col_tiles = ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS);
          s->n_dot_tiles = k / HMX_FP16_TILE_N_ROWS;

          worker_pool_synctoken_init(&s->sync_ctx, 1);
          worker_pool_submit(hmx_worker_pool_ctx, mm_task_job);
        }

        // compute D_{i}
        float *output_chunk = dst + (mr * n + nc);
        transfer_output_chunk_fp16_to_fp32(output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n);

        // wait for DMA (A_{i+2}), compute B_{i+2}
        if (i + 2 < n_chunk_cnt) {
          dma_wait_for_idle();
          dequantize_permuted_weight_chunk_qk_0_to_fp16_hvx(vtcm_weight_bufs[(i + 2) % 2], NULL, n_cols_p2 * k, k,
                                                            weight_type, vtcm_qweight);
        }
      }
    }
  }

  // FARF(ALWAYS, "%s: m = %d, k = %d, n = %d", __func__, m, k, n);
  // FARF(ALWAYS, "    activation load: %lld us", HAP_perf_qtimer_count_to_us(activation_load_time));
  // FARF(ALWAYS, "    weight     load: %lld us", HAP_perf_qtimer_count_to_us(weight_load_time));
  // FARF(ALWAYS, "    core     matmul: %lld us", HAP_perf_qtimer_count_to_us(hmx_core_time));
  // FARF(ALWAYS, "    output    store: %lld us", HAP_perf_qtimer_count_to_us(output_store_time));

  // size_t weight_size = k * n / QK_K * super_block_size;
  // float  bandwidth   = 1e-3 * weight_size / HAP_perf_qtimer_count_to_us(weight_load_time);
  // FARF(ALWAYS, "    weight load bandwidth: %.2f GB/s", bandwidth);

  return 0;
}

// C += AB
void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp16 *col_scales, const __fp16 *eye_tile,
                         int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) {
  hmx_unit_acquire();

  asm volatile("mxclracc.hf");
  hmx_set_output_scales(col_scales);

  for (int i = 0; i < n_row_tiles; ++i) {
    for (int j = 0; j < n_col_tiles; ++j) {
      const __fp16 *row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
      const __fp16 *col_tiles = b + j * n_dot_tiles * HMX_FP16_TILE_N_ELMS;

      __fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS;
      if (!zero_init) {
        hmx_load_tiles_fp16(accum_tile, eye_tile, 1);
      }

      for (int k = 0; k < n_dot_tiles; k += 32) {
        int    offset  = k * HMX_FP16_TILE_N_ELMS;
        size_t n_tiles = smin(n_dot_tiles - k, 32);
        hmx_load_tiles_fp16(row_tiles + offset, col_tiles + offset, n_tiles);
      }

      hmx_consume_accumulator_fp16(accum_tile);
    }
  }

  hmx_unit_release();
}

typedef struct {
  uint8_t           *dst;
  const uint8_t     *src;
  size_t             height;
  size_t             width;   // bytes
  size_t             stride;  // bytes
  worker_synctoken_t sync_ctx;
} qweight_fetch_task_state_t;

// Synchronously load 2D data from DDR to VTCM via DMA (using sequential 1D DMA transfers)
static void dma_load_2d_sync(uint8_t *dst, const uint8_t *src, size_t dst_stride, size_t src_stride, size_t height,
                             size_t width) {
  static dma_desc_1d_t desc __attribute__((aligned(64))) = { 0 };
  for (size_t i = 0; i < height; ++i) {
    size_t src_offset = i * src_stride;
    size_t dst_offset = i * dst_stride;
    dma_issue_load_from_ddr(&desc, dst + dst_offset, src + src_offset, width);
    dma_wait_for_idle();  // wait for the current row to be ready
  }
}

static void qweight_fetch_worker_fn(void *data, int _worker_index) {
  (void) _worker_index;
  qweight_fetch_task_state_t *st = (qweight_fetch_task_state_t *) data;

  dma_load_2d_sync(st->dst, st->src, st->width, st->stride, st->height, st->width);

  worker_pool_synctoken_jobdone(&st->sync_ctx);
}

// Only slightly faster than the common version (with L2 prefetch enabled) when doing VTCM to VTCM transfer
void transfer_activation_chunk_no_prefetch(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows,
                                           int k_block, int k_stride) {
  for (int r = 0; r < n_rows; r += 2) {
    int r0 = r / HMX_FP16_TILE_N_ROWS;  // tile row index
    int r1 = r % HMX_FP16_TILE_N_ROWS;  // intra-tile row idx

    const bool next_row_valid = (r + 1) < n_rows;

    const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride);
    const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride);
    for (int c = 0; c < k_block; c += 32) {
      HVX_Vector v0 = *pv_in0++;
      HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero();

      HVX_Vector v_out = hvx_my_wsf_to_vhf(v1, v0);

      // compute output position
      int c0       = c / HMX_FP16_TILE_N_COLS;  // tile column index
      int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0;

      HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS);
      tile[r1 / 2]     = v_out;
    }
  }
}

typedef struct {
  EXPAND_COMMON_TASK_STATE_MEMBERS
  __fp16      *dst;
  const float *src;
  int          k_block, k_stride;
} activation_transfer_task_state_t;

static void transfer_activation_chunk_worker_fn(void *data, int _worker_index) {
  (void) _worker_index;
  activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data;

  while (1) {
    unsigned int task_id = worker_pool_atomic_inc_return(&st->task_id) - 1;
    if (task_id >= st->n_tasks) {
      break;
    }
    // one chunk: one row
    int    chunk_idx  = task_id * st->n_chunks_per_task;
    size_t chunk_size = smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task);

    __fp16      *dst = st->dst + chunk_idx * st->k_block;
    const float *src = st->src + chunk_idx * st->k_stride;
    transfer_activation_chunk_no_prefetch(dst, src, chunk_size, st->k_block, st->k_stride);
  }

  worker_pool_synctoken_jobdone(&st->sync_ctx);
}

void transfer_activation_chunk_multithread(__fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) {
  int    n_workers         = num_hvx128_contexts;
  size_t n_tot_chunks      = n_rows;
  size_t n_chunks_per_task = 32;  // NOTE(hzx): must be multiple of 32 to ensure correct destination address

  activation_transfer_task_state_t state;
  INIT_COMMON_TASK_STATE_MEMBERS(state, n_tot_chunks, n_chunks_per_task);
  state.dst      = dst;
  state.src      = src;
  state.k_block  = k_block;
  state.k_stride = k_stride;

  worker_pool_job_t job;
  job.dptr = &state;
  job.fptr = &transfer_activation_chunk_worker_fn;

  worker_pool_synctoken_init(&state.sync_ctx, n_workers);
  for (int i = 0; i < n_workers; ++i) {
    worker_pool_submit(NULL, job);  // use default worker pool
  }
  worker_pool_synctoken_wait(&state.sync_ctx);
}

int mat_mul_qk_0_d16a32_out_stationary(float *restrict out, const float *restrict x, const uint8_t *restrict w, int m,
                                       int k, int n, enum ggml_type weight_type) {
  // NOTE(hzx): this constraint on k originates from 2D DMA, consider alternative ways to load activation
  assert(k < 16384);
  // assume k % 32 == 0 && n % 32 == 0
  const size_t super_block_size = get_super_block_size(weight_type);
  if (super_block_size == 0) {
    return -1;
  }

  uint8_t *vtcm_ptr        = (uint8_t *) vtcm_manager_get_vtcm_base();
  __fp16  *vtcm_weight     = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, WEIGHT_AREA_SIZE);
  __fp16  *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, ACTIVATION_AREA_SIZE);
  __fp16  *vtcm_output     = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, OUTPUT_AREA_SIZE);
  uint8_t *vtcm_scratch0   = vtcm_seq_alloc(&vtcm_ptr, SCRATCH_AREA_SIZE);
  uint8_t *vtcm_scratch1   = vtcm_seq_alloc(&vtcm_ptr, SCRATCH_AREA_SIZE * 2);
  __fp16  *vtcm_eye_tile   = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE);
  __fp16  *vtcm_scales     = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);

  // initialize eye tile (32x32 identity matrix)
  {
    HVX_Vector v;
    v = Q6_V_vzero();
    v = Q6_Vw_vinsert_VwR(v, 0x3c000000);
    v = Q6_V_vror_VR(v, VLEN - 4);
    v = Q6_Vw_vinsert_VwR(v, 0x00003c00);
    for (int i = 0; i < 16; ++i) {
      ((HVX_Vector *) vtcm_eye_tile)[i] = v;

      v = Q6_V_vror_VR(v, VLEN - 8);
    }
  }
  hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00));  // fp16: 1.0

  // 704, 512
  const size_t M_BLOCK_SIZE = 512;
  const size_t N_BLOCK_SIZE = 512;
  const size_t K_BLOCK_SIZE = 512;

  static worker_pool_context_t      fetch_task_worker_pool_ctx;
  static qweight_fetch_task_state_t fetch_task_state;
  static worker_pool_job_t          fetch_task_job;

  worker_pool_init_ex(&fetch_task_worker_pool_ctx, 4096, 1, 0);
  fetch_task_job.dptr = &fetch_task_state;
  fetch_task_job.fptr = &qweight_fetch_worker_fn;

  int64_t t_a, t_b, t_c;
  t_a = t_b = t_c = 0;

  for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) {
    size_t m_blk_sz = smin(m - mr, M_BLOCK_SIZE);
    for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) {
      size_t n_blk_sz = smin(n - nc, N_BLOCK_SIZE);

      const int n_row_tiles = ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS);
      const int n_col_tiles = ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS);

      // TODO(hzx): fully pipelined loop
      for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) {
        size_t k_blk_sz = smin(k - kk, K_BLOCK_SIZE);

        int64_t t0 = HAP_perf_get_qtimer_count();
        // fetch activation block into VTCN
        {
          const float *activation_block = x + mr * k + kk;

          _Alignas(64) dma_desc_2d_t desc = { 0 };

          desc.next       = 0;
          desc.length     = 0;
          desc.type       = DMA_DESC_TYPE_2D;
          desc.src_bypass = 1;
          desc.dst_bypass = 0;
          desc.ordered    = 1;
          desc.dstate     = DMA_DESC_DSTATE_PENDING;

          desc.src              = (uint32_t) activation_block;
          desc.dst              = (uint32_t) vtcm_scratch1;
          desc.roi_width        = k_blk_sz * sizeof(float);
          desc.roi_height       = m_blk_sz;
          desc.src_stride       = k * sizeof(float);
          desc.dst_stride       = k_blk_sz * sizeof(float);
          desc.src_width_offset = 0;
          desc.dst_width_offset = 0;

          dma_wait_for_idle();
          dma_submit_one((dma_desc_1d_t *) &desc);
          dma_wait_for_idle();
        }

        // fetch weight block into VTCM
        {
          qweight_fetch_task_state_t *s = &fetch_task_state;

          size_t width_ne  = HMX_FP16_TILE_N_COLS * k_blk_sz;
          size_t stride_ne = HMX_FP16_TILE_N_COLS * k;
          size_t width     = width_ne / QK_K * super_block_size;
          size_t stride    = stride_ne / QK_K * super_block_size;

          s->dst    = vtcm_scratch0;
          s->src    = w + (nc * k + HMX_FP16_TILE_N_COLS * kk) / QK_K * super_block_size;
          s->height = n_col_tiles;
          s->width  = width;
          s->stride = stride;

          worker_pool_synctoken_init(&s->sync_ctx, 1);
          worker_pool_submit(fetch_task_worker_pool_ctx, fetch_task_job);
        }
        t_a += HAP_perf_get_qtimer_count() - t0;

        int64_t t1 = HAP_perf_get_qtimer_count();
        // load activation block
        {
          // const float *activation_block = x + mr * k + kk;
          // transfer_activation_chunk_fp32_to_fp16(vtcm_activation, activation_block, m_blk_sz, k_blk_sz, k);
          // transfer_activation_chunk_multithread(vtcm_activation, activation_block, m_blk_sz, k_blk_sz, k);

          // NOTE(hzx): This code assumes that the activation block already resides in VTCM
          // transfer_activation_chunk_no_prefetch(vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz);
          transfer_activation_chunk_multithread(vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz);
        }

        // dequantize weight block
        {
          // vtcm_scratch0 is used to store the qweight chunk
          worker_pool_synctoken_wait(&fetch_task_state.sync_ctx);
          dequantize_permuted_weight_chunk_qk_0_to_fp16_hvx(vtcm_weight, NULL /*unused*/, n_blk_sz * k_blk_sz,
                                                            -1 /*unused*/, weight_type, vtcm_scratch0);
        }
        t_b += HAP_perf_get_qtimer_count() - t1;

        // core mma
        int64_t t2 = HAP_perf_get_qtimer_count();
        {
          core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles,
                              n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0);
        }
        t_c += HAP_perf_get_qtimer_count() - t2;
      }

      // store output block
      {
        float *output_block = out + (mr * n + nc);
        transfer_output_chunk_fp16_to_fp32(output_block, vtcm_output, m_blk_sz, n_blk_sz, n);
      }
    }
  }

  FARF(ALWAYS, "t_a: %lld us, t_b: %lld us, t_c: %lld us", HAP_perf_qtimer_count_to_us(t_a),
       HAP_perf_qtimer_count_to_us(t_b), HAP_perf_qtimer_count_to_us(t_c));

  worker_pool_deinit(&fetch_task_worker_pool_ctx);
  return 0;
}
